package com.chatplus.application.aiprocessor.interceptor;

import cn.bugstack.openai.exception.OpenAiSdkException;
import cn.bugstack.openai.executor.interceptor.KeyStrategyFunction;
import cn.bugstack.openai.executor.model.baidu.utils.AccessTokenUtils;
import cn.bugstack.openai.executor.model.chatglm.utils.BearerTokenUtils;
import cn.bugstack.openai.executor.model.xunfei.config.XunFeiConfig;
import cn.bugstack.openai.executor.model.xunfei.utils.URLAuthUtils;
import cn.bugstack.openai.session.Configuration;
import cn.bugstack.openai.session.defaults.SwitchProxySelector;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Validator;
import cn.hutool.extra.spring.SpringUtil;
import cn.hutool.http.Header;
import cn.hutool.http.HttpUtil;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.domain.dto.ApiKeyDto;
import com.chatplus.application.domain.dto.DevopsConfigDto;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.util.ConfigUtil;
import com.chatplus.application.util.email.MailUtils;
import lombok.Getter;
import lombok.Setter;
import okhttp3.*;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.redisson.api.RBucket;
import org.redisson.api.RedissonClient;

import java.text.MessageFormat;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Dynamic API-key interceptor.
 *
 * <p>For every call it picks a key from {@link #apiKeyList} (via {@link #keyStrategy}), rewrites
 * the request with that key's URL / credentials for {@link #platformChannel}, and fails over to
 * another key when the upstream call fails. Failed keys are removed from the list and an alert
 * is sent (see {@link #sendAlert(String)}).
 *
 * @author chj
 * @date 2024/3/29
 **/
@Setter
@Getter
public class DynamicKeyHandleInterceptor implements Interceptor {

    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(DynamicKeyHandleInterceptor.class);

    /**
     * Candidate API keys; failed keys are removed by {@link #onErrorDealApiKeys(String)}.
     */
    // NOTE(review): on failure this field is replaced with a filtered copy, but it is neither
    // volatile nor guarded. If one instance serves concurrent calls, two simultaneous removals can
    // race and one removed key may reappear — confirm how instances are shared.
    private List<ApiKeyDto> apiKeyList;

    /**
     * Platform whose authentication scheme is applied to outgoing requests.
     */
    private AiPlatformEnum platformChannel;

    /**
     * Strategy that picks one key from {@link #apiKeyList}; defaults to {@code KeyRandomStrategy}.
     */
    private KeyStrategyFunction<List<ApiKeyDto>, ApiKeyDto> keyStrategy = new KeyRandomStrategy();

    /**
     * Used to de-duplicate alerts per key; resolved from the Spring context at construction time.
     */
    private final RedissonClient redissonClient = SpringUtil.getBean(RedissonClient.class);

    public DynamicKeyHandleInterceptor(List<ApiKeyDto> apiKeyList, AiPlatformEnum platformChannel) {
        this.apiKeyList = apiKeyList;
        this.platformChannel = platformChannel;
    }

    /**
     * Hook called when no key is available. Logs an error by default; subclasses may override
     * it to plug in a custom alert.
     */
    protected void noHaveActiveKeyWarring() {
        LOGGER.message("--------> [告警] 没有可用的key！！！").error();
    }

    /**
     * Picks the key for the next request.
     *
     * @return the key chosen by {@link #keyStrategy}, or {@code null} when the key list is null or
     *         empty (in which case {@link #noHaveActiveKeyWarring()} has been called)
     */
    public final ApiKeyDto getKey() {
        if (CollUtil.isEmpty(apiKeyList)) {
            this.noHaveActiveKeyWarring();
            return null;
        }
        return keyStrategy.apply(apiKeyList);
    }

    /**
     * Builds the request to send with the given key: applies the key's proxy (if any), the common
     * headers, and the platform-specific URL and authentication.
     *
     * @param apiKeyDto key, target URL and optional proxy to use
     * @param chain     interceptor chain holding the original request
     * @return the rewritten request
     * @throws Exception if the key is malformed for the platform or a platform token cannot be obtained
     */
    private Request auth(ApiKeyDto apiKeyDto, Chain chain) throws Exception {
        String proxy = apiKeyDto.getProxyUrl();
        String key = apiKeyDto.getApiKey();
        String apiHost = apiKeyDto.getUrl();
        Request original = chain.request();
        Request.Builder builder = original.newBuilder();
        if (StringUtils.isNotEmpty(proxy)) {
            LOGGER.message("设置代理").context("proxyHeader", proxy).info();
            SwitchProxySelector.proxyThreadLocal.set(SwitchProxySelector.getProxy(proxy));
        } else {
            // The proxy is held per thread and dispatcher threads are reused; without clearing it,
            // a proxy set for an earlier key would leak into this proxy-less request.
            SwitchProxySelector.proxyThreadLocal.remove();
        }
        builder.url(original.url())
                .header("User-Agent", Configuration.DEFAULT_USER_AGENT)
                .header("Accept", Configuration.SSE_CONTENT_TYPE)
                .method(original.method(), original.body());
        switch (platformChannel) {
            case OPEN_AI, ALI_YUN_Q_WEN -> {
                builder.url(apiHost);
                builder.header(Header.AUTHORIZATION.getValue(), MessageFormat.format("Bearer {0}", key));
                builder.header("Content-Type", Configuration.JSON_CONTENT_TYPE);
            }
            case CHAT_GLM -> {
                // Key layout: "<id>.<secret>"
                String[] parts = splitKey(key, "\\.", 2);
                builder.url(apiHost);
                builder.header(Header.AUTHORIZATION.getValue(),
                        MessageFormat.format("Bearer {0}", BearerTokenUtils.getToken(parts[0], parts[1])));
            }
            case XUN_FEI -> {
                // Key layout: "<appId>|<apiKey>|<secretKey>"; auth goes into the URL plus an app-id header.
                String[] parts = splitKey(key, "\\|", 3);
                builder.url(URLAuthUtils.getAuthURl(apiHost, parts[1], parts[2]));
                builder.header(XunFeiConfig.XUN_FEI_APP_ID, parts[0]);
            }
            case BAI_DU -> {
                // Key layout: "<first>|<second>", exchanged for an access token passed as a query param.
                String[] parts = splitKey(key, "\\|", 2);
                String accessToken = AccessTokenUtils.getAccessToken(parts[0], parts[1]);
                if (StringUtils.isEmpty(accessToken)) {
                    throw new OpenAiSdkException("获取百度Token失败");
                }
                builder.url(apiHost.concat("?access_token=").concat(accessToken));
            }
            default -> {
                // Other platforms: keep the original URL and only the common headers.
            }
        }
        return builder.build();
    }

    /**
     * Splits a composite key and checks it has enough parts.
     *
     * @param key            raw key (may be {@code null})
     * @param separatorRegex regex separating the parts
     * @param minParts       minimum number of parts required
     * @return the split parts (at least {@code minParts} of them)
     * @throws OpenAiSdkException if the key is {@code null} or has too few parts; the key itself is
     *                            deliberately kept out of the message
     */
    private static String[] splitKey(String key, String separatorRegex, int minParts) throws OpenAiSdkException {
        String[] parts = key == null ? new String[0] : key.split(separatorRegex);
        if (parts.length < minParts) {
            throw new OpenAiSdkException("ApiKey格式错误");
        }
        return parts;
    }

    /**
     * Removes every entry whose key equals {@code apiKey} from {@link #apiKeyList}, logs the removal,
     * and triggers {@link #sendAlert(String)}.
     *
     * @param apiKey the failed key (may be {@code null}; compared null-safely)
     * @return the updated key list
     */
    protected List<ApiKeyDto> onErrorDealApiKeys(String apiKey) {
        apiKeyList = getApiKeyList().stream()
                .filter(e -> !StringUtils.equals(apiKey, e.getApiKey()))
                .collect(Collectors.toList());
        LOGGER.message("--------> 当前ApiKey失效了，移除！").context("errorKey", apiKey).error();
        // Hook for alerting on the failed key.
        sendAlert(apiKey);
        return apiKeyList;
    }

    /**
     * Best-effort alert for a failed key. When alerting is enabled in the devops config, sends an
     * e-mail and/or a webhook POST, then stores a Redis marker so the same key is not alerted again
     * for {@code alertInterval} minutes (10 when unset). Any failure is logged, never thrown.
     *
     * @param apiKey the failed key
     */
    // NOTE(review): the alert text and logs contain the raw API key; consider masking it.
    // NOTE(review): isExists() followed by set() is not atomic, so concurrent failures of the same
    // key may each send an alert.
    private void sendAlert(String apiKey) {
        try {
            DevopsConfigDto devopsConfigDto = ConfigUtil.getDevopsConfig();
            if (devopsConfigDto == null || !Boolean.TRUE.equals(devopsConfigDto.getEnableAlert())) {
                return;
            }
            String alertEmail = devopsConfigDto.getAlertMail();
            String alertWebHook = devopsConfigDto.getAlertWebhook();
            int alertInterval = devopsConfigDto.getAlertInterval() != null ? devopsConfigDto.getAlertInterval() : 10;
            RBucket<String> rBucket = redissonClient.getBucket(String.format("alert:apiKey:%s", apiKey));
            if (rBucket.isExists()) {
                return;
            }
            String alertMsg = String.format("平台【%s】，失效的ApiKey【%s】,当前可用key数量【%s】", platformChannel.getName(), apiKey, apiKeyList.size());
            // E-mail alert
            if (StringUtils.isNotEmpty(alertEmail) && Validator.isEmail(alertEmail)) {
                MailUtils.sendText(alertEmail, "ApiKey失效告警", alertMsg);
            }
            // Webhook alert
            if (StringUtils.isNotEmpty(alertWebHook)) {
                HttpUtil.post(alertWebHook, Map.of("msg", alertMsg));
            }
            rBucket.set("1", Duration.ofMinutes(alertInterval));
            LOGGER.message("--------> 发送告警成功").context("errorKey", apiKey).context("alertEmail", alertEmail).context("alertWebHook", alertWebHook).info();
        } catch (Exception e) {
            LOGGER.message("--------> 发送告警失败").context("errorKey", apiKey).exception(e).error();
        }
    }

    @NotNull
    @Override
    public Response intercept(@NotNull Chain chain) {
        return retry(chain);
    }

    /**
     * Sends the request with a key from {@link #apiKeyList}, failing over to the remaining keys.
     *
     * <p>A key is dropped via {@link #onErrorDealApiKeys(String)} when its request throws or returns
     * a failed status, and the request is retried with another key. A failed status is not retried
     * when no key is left or the platform is {@link AiPlatformEnum#BAI_DU}. A call canceled by the
     * caller neither drops the key nor retries. With the default {@code onErrorDealApiKeys}, each
     * retry follows a removal, so recursion depth is bounded by the number of keys.
     *
     * @param chain interceptor chain
     * @return the successful upstream response; or the last failed upstream response (body still
     *         readable); or a synthetic 500 "ApiKey失效" response when no key could be used
     */
    Response retry(Chain chain) {
        ApiKeyDto apiKeyDto = getKey();
        if (apiKeyDto == null) {
            return buildKeyInvalidResponse(chain);
        }
        String key = apiKeyDto.getApiKey();
        Response response = null;
        try {
            Request request = this.auth(apiKeyDto, chain);
            response = chain.proceed(request);
            if (!isFailedResponse(response)) {
                return response;
            }
            String errorMsg = "未知错误";
            if (response.body() != null) {
                // peekBody buffers a copy, so the original body stays readable if this response
                // is handed back to the caller below.
                errorMsg = response.peekBody(Long.MAX_VALUE).string();
            }
            LOGGER.message("--------> 请求失败！").context("key", key).context("errorMsg", errorMsg).error();
        } catch (Exception e) {
            closeBody(response);
            if (chain.call().isCanceled()) {
                // Canceled by the caller: the key is not at fault and every further attempt would
                // fail the same way, so neither drop keys nor retry.
                return buildKeyInvalidResponse(chain);
            }
            LOGGER.message("--------> 请求失败！").context("key", key).exception(e).error();
            setApiKeyList(this.onErrorDealApiKeys(key));
            return retry(chain);
        }
        setApiKeyList(this.onErrorDealApiKeys(key));
        // No key left, or Baidu (not retried for now): return the failed upstream response as-is.
        if (CollUtil.isEmpty(apiKeyList) || platformChannel == AiPlatformEnum.BAI_DU) {
            return response;
        }
        // About to retry with another key: release the failed response first.
        closeBody(response);
        return retry(chain);
    }

    /**
     * Decides whether an upstream response counts as a failure of the current key: any non-2xx
     * status, except Xunfei's websocket handshake (101 with {@code Upgrade: websocket}).
     */
    private boolean isFailedResponse(Response response) {
        boolean xunFeiHandshake = platformChannel == AiPlatformEnum.XUN_FEI
                && "websocket".equals(response.header("Upgrade"))
                && response.code() == 101;
        return !response.isSuccessful() && !xunFeiHandshake;
    }

    /**
     * Builds the synthetic 500 response returned when no usable key remains.
     */
    private static Response buildKeyInvalidResponse(Chain chain) {
        return new Response.Builder()
                .request(chain.request())
                .body(ResponseBody.create("ApiKey失效", MediaType.get("application/json")))
                .protocol(Protocol.HTTP_2)
                .code(500).message("ApiKey失效")
                .build();
    }

    /**
     * Closes the response body if there is one; a {@code null} response or body is ignored.
     */
    private static void closeBody(Response response) {
        ResponseBody body = response == null ? null : response.body();
        if (body != null) {
            body.close();
        }
    }
}
